{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\masaikk\\anaconda3\\envs\\mytorch\\lib\\site-packages\\gym\\envs\\classic_control\\cartpole.py:211: UserWarning: \u001B[33mWARN: You are calling render method without specifying any render mode. You can specify the render_mode at initialization, e.g. gym(\"CartPole-v0\", render_mode=\"rgb_array\")\u001B[0m\n",
      "  gym.logger.warn(\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Image data of dtype object cannot be converted to float",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
      "Input \u001B[1;32mIn [2]\u001B[0m, in \u001B[0;36m<cell line: 16>\u001B[1;34m()\u001B[0m\n\u001B[0;32m     12\u001B[0m     plt\u001B[38;5;241m.\u001B[39mimshow(env\u001B[38;5;241m.\u001B[39mrender())\n\u001B[0;32m     13\u001B[0m     plt\u001B[38;5;241m.\u001B[39mshow()\n\u001B[1;32m---> 16\u001B[0m \u001B[43mshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n",
      "Input \u001B[1;32mIn [2]\u001B[0m, in \u001B[0;36mshow\u001B[1;34m()\u001B[0m\n\u001B[0;32m     11\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mshow\u001B[39m():\n\u001B[1;32m---> 12\u001B[0m     \u001B[43mplt\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrender\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     13\u001B[0m     plt\u001B[38;5;241m.\u001B[39mshow()\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\_api\\deprecation.py:454\u001B[0m, in \u001B[0;36mmake_keyword_only.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m    448\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(args) \u001B[38;5;241m>\u001B[39m name_idx:\n\u001B[0;32m    449\u001B[0m     warn_deprecated(\n\u001B[0;32m    450\u001B[0m         since, message\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPassing the \u001B[39m\u001B[38;5;132;01m%(name)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;132;01m%(obj_type)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    451\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpositionally is deprecated since Matplotlib \u001B[39m\u001B[38;5;132;01m%(since)s\u001B[39;00m\u001B[38;5;124m; the \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    452\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter will become keyword-only \u001B[39m\u001B[38;5;132;01m%(removal)s\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m    453\u001B[0m         name\u001B[38;5;241m=\u001B[39mname, obj_type\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfunc\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m()\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m--> 454\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\pyplot.py:2623\u001B[0m, in \u001B[0;36mimshow\u001B[1;34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001B[0m\n\u001B[0;32m   2617\u001B[0m \u001B[38;5;129m@_copy_docstring_and_deprecators\u001B[39m(Axes\u001B[38;5;241m.\u001B[39mimshow)\n\u001B[0;32m   2618\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mimshow\u001B[39m(\n\u001B[0;32m   2619\u001B[0m         X, cmap\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, norm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, aspect\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, interpolation\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m   2620\u001B[0m         alpha\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmax\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, origin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, extent\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m,\n\u001B[0;32m   2621\u001B[0m         interpolation_stage\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, filternorm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, filterrad\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4.0\u001B[39m,\n\u001B[0;32m   2622\u001B[0m         resample\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, url\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m-> 2623\u001B[0m     __ret \u001B[38;5;241m=\u001B[39m 
\u001B[43mgca\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m   2624\u001B[0m \u001B[43m        \u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcmap\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcmap\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnorm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnorm\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maspect\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maspect\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2625\u001B[0m \u001B[43m        \u001B[49m\u001B[43minterpolation\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minterpolation\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43malpha\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43malpha\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvmin\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvmin\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2626\u001B[0m \u001B[43m        \u001B[49m\u001B[43mvmax\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvmax\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43morigin\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43morigin\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mextent\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mextent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2627\u001B[0m \u001B[43m        \u001B[49m\u001B[43minterpolation_stage\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minterpolation_stage\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2628\u001B[0m \u001B[43m        \u001B[49m\u001B[43mfilternorm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfilternorm\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfilterrad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfilterrad\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mresample\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mresample\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2629\u001B[0m \u001B[43m        \u001B[49m\u001B[43murl\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43murl\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mdata\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[43m}\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mif\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;129;43;01mis\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;129;43;01mnot\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;28;43;01melse\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43m{\u001B[49m\u001B[43m}\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m   2630\u001B[0m \u001B[43m        \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   2631\u001B[0m     sci(__ret)\n\u001B[0;32m   2632\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m __ret\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\_api\\deprecation.py:454\u001B[0m, in \u001B[0;36mmake_keyword_only.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m    448\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(args) \u001B[38;5;241m>\u001B[39m name_idx:\n\u001B[0;32m    449\u001B[0m     warn_deprecated(\n\u001B[0;32m    450\u001B[0m         since, message\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPassing the \u001B[39m\u001B[38;5;132;01m%(name)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;132;01m%(obj_type)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    451\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpositionally is deprecated since Matplotlib \u001B[39m\u001B[38;5;132;01m%(since)s\u001B[39;00m\u001B[38;5;124m; the \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    452\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter will become keyword-only \u001B[39m\u001B[38;5;132;01m%(removal)s\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m    453\u001B[0m         name\u001B[38;5;241m=\u001B[39mname, obj_type\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfunc\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m()\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m--> 454\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\__init__.py:1423\u001B[0m, in \u001B[0;36m_preprocess_data.<locals>.inner\u001B[1;34m(ax, data, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1420\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m   1421\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21minner\u001B[39m(ax, \u001B[38;5;241m*\u001B[39margs, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m   1422\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m-> 1423\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43max\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;28;43mmap\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43msanitize_sequence\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   1425\u001B[0m     bound \u001B[38;5;241m=\u001B[39m new_sig\u001B[38;5;241m.\u001B[39mbind(ax, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m   1426\u001B[0m     auto_label \u001B[38;5;241m=\u001B[39m (bound\u001B[38;5;241m.\u001B[39marguments\u001B[38;5;241m.\u001B[39mget(label_namer)\n\u001B[0;32m   1427\u001B[0m                   \u001B[38;5;129;01mor\u001B[39;00m bound\u001B[38;5;241m.\u001B[39mkwargs\u001B[38;5;241m.\u001B[39mget(label_namer))\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\axes\\_axes.py:5604\u001B[0m, in \u001B[0;36mAxes.imshow\u001B[1;34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001B[0m\n\u001B[0;32m   5596\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mset_aspect(aspect)\n\u001B[0;32m   5597\u001B[0m im \u001B[38;5;241m=\u001B[39m mimage\u001B[38;5;241m.\u001B[39mAxesImage(\u001B[38;5;28mself\u001B[39m, cmap\u001B[38;5;241m=\u001B[39mcmap, norm\u001B[38;5;241m=\u001B[39mnorm,\n\u001B[0;32m   5598\u001B[0m                       interpolation\u001B[38;5;241m=\u001B[39minterpolation, origin\u001B[38;5;241m=\u001B[39morigin,\n\u001B[0;32m   5599\u001B[0m                       extent\u001B[38;5;241m=\u001B[39mextent, filternorm\u001B[38;5;241m=\u001B[39mfilternorm,\n\u001B[0;32m   5600\u001B[0m                       filterrad\u001B[38;5;241m=\u001B[39mfilterrad, resample\u001B[38;5;241m=\u001B[39mresample,\n\u001B[0;32m   5601\u001B[0m                       interpolation_stage\u001B[38;5;241m=\u001B[39minterpolation_stage,\n\u001B[0;32m   5602\u001B[0m                       \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m-> 5604\u001B[0m \u001B[43mim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mset_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   5605\u001B[0m im\u001B[38;5;241m.\u001B[39mset_alpha(alpha)\n\u001B[0;32m   5606\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m im\u001B[38;5;241m.\u001B[39mget_clip_path() \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m   5607\u001B[0m     \u001B[38;5;66;03m# image does not already have clipping set, clip to axes patch\u001B[39;00m\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\mytorch\\lib\\site-packages\\matplotlib\\image.py:701\u001B[0m, in \u001B[0;36m_ImageBase.set_data\u001B[1;34m(self, A)\u001B[0m\n\u001B[0;32m    697\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A \u001B[38;5;241m=\u001B[39m cbook\u001B[38;5;241m.\u001B[39msafe_masked_invalid(A, copy\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[0;32m    699\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mdtype \u001B[38;5;241m!=\u001B[39m np\u001B[38;5;241m.\u001B[39muint8 \u001B[38;5;129;01mand\u001B[39;00m\n\u001B[0;32m    700\u001B[0m         \u001B[38;5;129;01mnot\u001B[39;00m np\u001B[38;5;241m.\u001B[39mcan_cast(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mdtype, \u001B[38;5;28mfloat\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124msame_kind\u001B[39m\u001B[38;5;124m\"\u001B[39m)):\n\u001B[1;32m--> 701\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mImage data of dtype \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m cannot be converted to \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    702\u001B[0m                     \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mfloat\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mdtype))\n\u001B[0;32m    704\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m] \u001B[38;5;241m==\u001B[39m 
\u001B[38;5;241m1\u001B[39m:\n\u001B[0;32m    705\u001B[0m     \u001B[38;5;66;03m# If just one dimension assume scalar and apply colormap\u001B[39;00m\n\u001B[0;32m    706\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A[:, :, \u001B[38;5;241m0\u001B[39m]\n",
      "\u001B[1;31mTypeError\u001B[0m: Image data of dtype object cannot be converted to float"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbB0lEQVR4nO3df0zd1f3H8RfQcqmx0DrGhbKrrHX+tqWCZVgb53IniQbXPxaZNYURf0xlRnuz2WJbUKulq7Yjs2hj1ekfOqpGjbEEp0xiVJZGWhKdbU2lFWa8tyWu3I4qtNzz/WPfXocFywf50bc8H8nnD84+537OPWH36b2995LgnHMCAMCYxIleAAAAI0HAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACZ5Dtjbb7+t4uJizZo1SwkJCXrllVdOOqe5uVmXXHKJfD6fzj77bD399NMjWCoAAF/zHLCenh7NmzdPdXV1wzp/3759uuaaa3TllVeqra1Nd911l2666Sa9/vrrnhcLAMBxCd/ly3wTEhL08ssva/HixUOes3z5cm3btk0ffvhhfOzXv/61Dh06pMbGxpFeGgAwyU0Z6wu0tLQoGAwOGCsqKtJdd9015Jze3l719vbGf47FYvriiy/0gx/8QAkJCWO1VADAGHDO6fDhw5o1a5YSE0fvrRdjHrBwOCy/3z9gzO/3KxqN6ssvv9S0adNOmFNTU6P77rtvrJcGABhHnZ2d+tGPfjRqtzfmARuJyspKhUKh+M/d3d0688wz1dnZqdTU1AlcGQDAq2g0qkAgoOnTp4/q7Y55wDIzMxWJRAaMRSIRpaamDvrsS5J8Pp98Pt8J46mpqQQMAIwa7X8CGvPPgRUWFqqpqWnA2BtvvKHCwsKxvjQA4HvMc8D+85//qK2tTW1tbZL++zb5trY2dXR0SPrvy3+lpaXx82+99Va1t7fr7rvv1u7du/Xoo4/q+eef17Jly0bnHgAAJiXPAXv//fc1f/58zZ8/X5IUCoU0f/58VVVVSZI+//zzeMwk6cc//rG2bdumN954Q/PmzdOGDRv0xBNPqKioaJTuAgBgMvpOnwMbL9FoVGlpaeru7ubfwADAmLF6DOe7EAEAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYNKIAlZXV6ecnBylpKSooKBA27dv/9bza2trde6552ratGkKBAJatmyZvvrqqxEtGAAAaQQB27p1q0KhkKqrq7Vjxw7NmzdPRUVFOnDgwKDnP/fcc1qxYoWqq6u1a9cuPfnkk9q6davuueee77x4AMDk5TlgGzdu1M0336zy8nJdcMEF2rx5s0477TQ99dRTg57/3nvvaeHChVqyZIlycnJ01VVX6frrrz/pszYAAL6Np4D19fWptbVVwWDw6xtITFQwGFRLS8ugcy677DK1trbGg9Xe3q6GhgZdffXVQ16nt7dX0Wh0wAEAwP+a4uXkrq4u9ff3y+/3Dxj3+/3avXv3oHOWLFmirq4uXX755XLO6dixY7r11lu/9SXEmpoa3XfffV6WBgCYZMb8XYjNzc1au3atHn30Ue3YsUMvvfSStm3bpjVr1gw5p7KyUt3d3fGj
s7NzrJcJADDG0zOw9PR0JSUlKRKJDBiPRCLKzMwcdM7q1au1dOlS3XTTTZKkiy++WD09Pbrlllu0cuVKJSae2FCfzyefz+dlaQCAScbTM7Dk5GTl5eWpqakpPhaLxdTU1KTCwsJB5xw5cuSESCUlJUmSnHNe1wsAgCSPz8AkKRQKqaysTPn5+VqwYIFqa2vV09Oj8vJySVJpaamys7NVU1MjSSouLtbGjRs1f/58FRQUaO/evVq9erWKi4vjIQMAwCvPASspKdHBgwdVVVWlcDis3NxcNTY2xt/Y0dHRMeAZ16pVq5SQkKBVq1bps88+0w9/+EMVFxfrwQcfHL17AQCYdBKcgdfxotGo0tLS1N3drdTU1IleDgDAg7F6DOe7EAEAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYNKIAlZXV6ecnBylpKSooKBA27dv/9bzDx06pIqKCmVlZcnn8+mcc85RQ0PDiBYMAIAkTfE6YevWrQqFQtq8ebMKCgpUW1uroqIi7dmzRxkZGSec39fXp1/84hfKyMjQiy++qOzsbH366aeaMWPGaKwfADBJJTjnnJcJBQUFuvTSS7Vp0yZJUiwWUyAQ0B133KEVK1accP7mzZv10EMPaffu3Zo6deqIFhmNRpWWlqbu7m6lpqaO6DYAABNjrB7DPb2E2NfXp9bWVgWDwa9vIDFRwWBQLS0tg8559dVXVVhYqIqKCvn9fl100UVau3at+vv7h7xOb2+votHogAMAgP/lKWBdXV3q7++X3+8fMO73+xUOhwed097erhdffFH9/f1qaGjQ6tWrtWHDBj3wwANDXqempkZpaWnxIxAIeFkmAGASGPN3IcZiMWVkZOjxxx9XXl6eSkpKtHLlSm3evHnIOZWVleru7o4fnZ2dY71MAIAxnt7EkZ6erqSkJEUikQHjkUhEmZmZg87JysrS1KlTlZSUFB87//zzFQ6H1dfXp+Tk5BPm+Hw++Xw+L0sDAEwynp6BJScnKy8vT01NTfGxWCympqYmFRYWDjpn4cKF2rt3r2KxWHzs448/VlZW1qDxAgBgODy/hBgKhbRlyxY988wz2rVrl2677Tb19PSovLxcklRaWqrKysr4+bfddpu++OIL3Xnnnfr444+1bds2rV27VhUVFaN3LwAAk47nz4GVlJTo4MGDqqqqUjgcVm5urhobG+Nv7Ojo6FBi4tddDAQCev3117Vs2TLNnTtX2dnZuvPOO7V8+fLRuxcAgEnH8+fAJgKfAwMAu06Jz4EBAHCqIGAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAA
kwgYAMAkAgYAMImAAQBMImAAAJMIGADApBEFrK6uTjk5OUpJSVFBQYG2b98+rHn19fVKSEjQ4sWLR3JZAADiPAds69atCoVCqq6u1o4dOzRv3jwVFRXpwIED3zpv//79+v3vf69FixaNeLEAABznOWAbN27UzTffrPLycl1wwQXavHmzTjvtND311FNDzunv79cNN9yg++67T7Nnzz7pNXp7exWNRgccAAD8L08B6+vrU2trq4LB4Nc3kJioYDColpaWIefdf//9ysjI0I033jis69TU1CgtLS1+BAIBL8sEAEwCngLW1dWl/v5++f3+AeN+v1/hcHjQOe+8846efPJJbdmyZdjXqaysVHd3d/zo7Oz0skwAwCQwZSxv/PDhw1q6dKm2bNmi9PT0Yc/z+Xzy+XxjuDIAgHWeApaenq6kpCRFIpEB45FIRJmZmSec/8knn2j//v0qLi6Oj8Visf9eeMoU7dmzR3PmzBnJugEAk5ynlxCTk5OVl5enpqam+FgsFlNTU5MKCwtPOP+8887TBx98oLa2tvhx7bXX6sorr1RbWxv/tgUAGDHPLyGGQiGVlZUpPz9fCxYsUG1trXp6elReXi5JKi0tVXZ2tmpqapSSkqKLLrpowPwZM2ZI0gnjAAB44TlgJSUlOnjwoKqqqhQOh5Wbm6vGxsb4Gzs6OjqUmMgXfAAAxlaCc85N9CJOJhqNKi0tTd3d3UpNTZ3o5QAAPBirx3CeKgEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwKQRBayurk45OTlKSUlRQUGBtm/fPuS5W7Zs0aJFizRz5kzNnDlTwWDwW88HAGA4PAds69atCoVCqq6u1o4dOzRv3jwVFRXpwIEDg57f3Nys66+/Xm+99ZZaWloUCAR01VVX6bPPPvvOiwcATF4JzjnnZUJBQYEuvfRSbdq0SZIUi8UUCAR0xx13aMWKFSed39/fr5kzZ2rTpk0qLS0d9Jze3l719vbGf45GowoEAuru7lZqaqqX5QIAJlg0GlVaWtqoP4Z7egbW19en1tZWBYPBr28gMVHBYFAtLS3Duo0jR47o6NGjOuOMM4Y8p6amRmlpafEjEAh4WSYAYBLwFLCuri719/fL7/cPGPf7/QqHw8O6jeXLl2vWrFkDIvhNlZWV6u7ujh+dnZ1elgkAmASmjOfF1q1bp/r6ejU3NyslJWXI83w+n3w+3ziuDABgjaeApaenKykpSZFIZMB4JBJRZmbmt859+OGHtW7dOr355puaO3eu95UCAPA/PL2EmJycrLy8PDU1NcXHYrGYmpqaVFhYOOS89evXa82aNWpsbFR+fv7IVwsAwP/z/BJiKBRSWVmZ8vPztWDBAtXW1qqnp0fl5eWSpNLSUmVnZ6umpkaS9Mc//lFVVVV67rnnlJOTE/+3stNPP12nn376KN4VAMBk4jlgJSUlOnjwoKqqqhQOh5Wbm6vGxsb4Gzs6OjqUmPj1E7vHHntMfX19+tWvfjXgdqqrq3Xvvfd+t9UDACYtz58Dmwhj9RkCAMDYOyU+BwYAwKmCgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJg
AACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTRhSwuro65eTkKCUlRQUFBdq+ffu3nv/CCy/ovPPOU0pKii6++GI1NDSMaLEAABznOWBbt25VKBRSdXW1duzYoXnz5qmoqEgHDhwY9Pz33ntP119/vW688Ubt3LlTixcv1uLFi/Xhhx9+58UDACavBOec8zKhoKBAl156qTZt2iRJisViCgQCuuOOO7RixYoTzi8pKVFPT49ee+21+NhPf/pT5ebmavPmzYNeo7e3V729vfGfu7u7deaZZ6qzs1OpqalelgsAmGDRaFSBQECHDh1SWlra6N2w86C3t9clJSW5l19+ecB4aWmpu/baawedEwgE3J/+9KcBY1VVVW7u3LlDXqe6utpJ4uDg4OD4Hh2ffPKJl+Sc1BR50NXVpf7+fvn9/gHjfr9fu3fvHnROOBwe9PxwODzkdSorKxUKheI/Hzp0SGeddZY6OjpGt97fM8f/K4dnqt+OfTo59mh42KfhOf4q2hlnnDGqt+spYOPF5/PJ5/OdMJ6WlsYvyTCkpqayT8PAPp0cezQ87NPwJCaO7hvfPd1aenq6kpKSFIlEBoxHIhFlZmYOOiczM9PT+QAADIengCUnJysvL09NTU3xsVgspqamJhUWFg46p7CwcMD5kvTGG28MeT4AAMPh+SXEUCiksrIy5efna8GCBaqtrVVPT4/Ky8slSaWlpcrOzlZNTY0k6c4779QVV1yhDRs26JprrlF9fb3ef/99Pf7448O+ps/nU3V19aAvK+Jr7NPwsE8nxx4ND/s0PGO1T57fRi9JmzZt0kMPPaRwOKzc3Fz9+c9/VkFBgSTpZz/7mXJycvT000/Hz3/hhRe0atUq7d+/Xz/5yU+0fv16XX311aN2JwAAk8+IAgYAwETjuxABACYRMACASQQMAGASAQMAmHTKBIw/0TI8XvZpy5YtWrRokWbOnKmZM2cqGAyedF+/D7z+Lh1XX1+vhIQELV68eGwXeIrwuk+HDh1SRUWFsrKy5PP5dM4550yK/9953afa2lqde+65mjZtmgKBgJYtW6avvvpqnFY7Md5++20VFxdr1qxZSkhI0CuvvHLSOc3Nzbrkkkvk8/l09tlnD3jn+rCN6jcrjlB9fb1LTk52Tz31lPvnP//pbr75ZjdjxgwXiUQGPf/dd991SUlJbv369e6jjz5yq1atclOnTnUffPDBOK98fHndpyVLlri6ujq3c+dOt2vXLveb3/zGpaWluX/961/jvPLx43WPjtu3b5/Lzs52ixYtcr/85S/HZ7ETyOs+9fb2uvz8fHf11Ve7d955x+3bt881Nze7tra2cV75+PK6T88++6zz+Xzu2Wefdfv27XOvv/66y8rKcsuWLRvnlY+vhoYGt3LlSvfSSy85SSd84fs3tbe3u9NOO82FQiH30UcfuUceecQlJSW5xsZGT9c9JQK2YMECV1FREf+5v7/fzZo1y9XU1Ax6/nXXXeeuueaaAWMFBQXut7/97Ziuc6J53advOnbsmJs+fbp75plnxmqJE24ke3Ts2DF32WWXuSeeeMKVlZVNioB53afHHnvMzZ492/X19Y3XEk8JXvepoqLC/fznPx8wFgqF3MKFC8d0naeS4QTs7rvvdhdeeOGAsZKSEldUVOTpWhP+EmJfX59aW1sVDAbjY4mJiQoGg2ppaRl0TktLy4DzJamoqGjI878PRrJP33TkyBEd
PXp01L8R+lQx0j26//77lZGRoRtvvHE8ljnhRrJPr776qgoLC1VRUSG/36+LLrpIa9euVX9//3gte9yNZJ8uu+wytba2xl9mbG9vV0NDA1/c8A2j9Rg+4d9GP15/osW6kezTNy1fvlyzZs064Rfn+2Ike/TOO+/oySefVFtb2zis8NQwkn1qb2/X3//+d91www1qaGjQ3r17dfvtt+vo0aOqrq4ej2WPu5Hs05IlS9TV1aXLL79czjkdO3ZMt956q+65557xWLIZQz2GR6NRffnll5o2bdqwbmfCn4FhfKxbt0719fV6+eWXlZKSMtHLOSUcPnxYS5cu1ZYtW5Senj7RyzmlxWIxZWRk6PHHH1deXp5KSkq0cuXKIf+q+mTV3NystWvX6tFHH9WOHTv00ksvadu2bVqzZs1EL+17acKfgfEnWoZnJPt03MMPP6x169bpzTff1Ny5c8dymRPK6x598skn2r9/v4qLi+NjsVhMkjRlyhTt2bNHc+bMGdtFT4CR/C5lZWVp6tSpSkpKio+df/75CofD6uvrU3Jy8piueSKMZJ9Wr16tpUuX6qabbpIkXXzxxerp6dEtt9yilStXjvrfw7JqqMfw1NTUYT/7kk6BZ2D8iZbhGck+SdL69eu1Zs0aNTY2Kj8/fzyWOmG87tF5552nDz74QG1tbfHj2muv1ZVXXqm2tjYFAoHxXP64Gcnv0sKFC7V379544CXp448/VlZW1vcyXtLI9unIkSMnROp49B1fOxs3ao/h3t5fMjbq6+udz+dzTz/9tPvoo4/cLbfc4mbMmOHC4bBzzrmlS5e6FStWxM9/99133ZQpU9zDDz/sdu3a5aqrqyfN2+i97NO6detccnKye/HFF93nn38ePw4fPjxRd2HMed2jb5os70L0uk8dHR1u+vTp7ne/+53bs2ePe+2111xGRoZ74IEHJuoujAuv+1RdXe2mT5/u/vrXv7r29nb3t7/9zc2ZM8ddd911E3UXxsXhw4fdzp073c6dO50kt3HjRrdz50736aefOuecW7FihVu6dGn8/ONvo//DH/7gdu3a5erq6uy+jd455x555BF35plnuuTkZLdgwQL3j3/8I/6/XXHFFa6srGzA+c8//7w755xzXHJysrvwwgvdtm3bxnnFE8PLPp111llO0glHdXX1+C98HHn9XfpfkyVgznnfp/fee88VFBQ4n8/nZs+e7R588EF37NixcV71+POyT0ePHnX33nuvmzNnjktJSXGBQMDdfvvt7t///vf4L3wcvfXWW4M+1hzfm7KyMnfFFVecMCc3N9clJye72bNnu7/85S+er8ufUwEAmDTh/wYGAMBIEDAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDS/wFzTP77mPX4nAAAAABJRU5ErkJggg==\n"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# Create the environment. A render mode must be chosen at construction time:\n",
    "# with render_mode='rgb_array', env.render() returns the current frame as an\n",
    "# image array that plt.imshow() can draw. Without it, render() returns nothing\n",
    "# drawable and imshow() fails with a dtype-object TypeError.\n",
    "env = gym.make('CartPole-v0', render_mode='rgb_array')\n",
    "env.reset()\n",
    "\n",
    "\n",
    "def show():\n",
    "    \"\"\"Draw the current frame of the global ``env`` inline with matplotlib.\"\"\"\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态\n",
      "state= (array([ 0.02867511, -0.02715833, -0.00734932, -0.04049598], dtype=float32), {})\n",
      "这个游戏一共有2个动作,不是0就是1\n",
      "env.action_space= Discrete(2)\n",
      "随机一个动作\n",
      "action= 0\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [ 0.02813194 -0.22217412 -0.00815924  0.24985914]\n",
      "reward= 1.0\n",
      "over= False\n"
     ]
    }
   ],
   "source": [
    "def test_env():\n",
    "    \"\"\"Smoke-test the environment: reset it, sample one action, step once and print the results.\"\"\"\n",
    "    # reset() returns (observation, info) in this gym API; keep only the\n",
    "    # observation so that `state` really is the 4-number vector described below\n",
    "    state, _ = env.reset()\n",
    "    print('这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    # e.g. state= [ 0.03490619  0.04873464  0.04908862 -0.00375859]\n",
    "\n",
    "    print('这个游戏一共有2个动作,不是0就是1')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    # e.g. env.action_space= Discrete(2)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    # e.g. action= 1\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, terminated, truncated, _ = env.step(action)\n",
    "    # the episode is over when it either terminated or was truncated;\n",
    "    # looking at only one of the two flags misses the other way of ending\n",
    "    over = terminated or truncated\n",
    "\n",
    "    print('state=', state)\n",
    "    # e.g. state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    # e.g. reward= 1.0\n",
    "\n",
    "    print('over=', over)\n",
    "    # e.g. over= False\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "(Sequential(\n   (0): Linear(in_features=4, out_features=128, bias=True)\n   (1): ReLU()\n   (2): Linear(in_features=128, out_features=2, bias=True)\n ),\n Sequential(\n   (0): Linear(in_features=4, out_features=128, bias=True)\n   (1): ReLU()\n   (2): Linear(in_features=128, out_features=2, bias=True)\n ))"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#计算动作的模型,也是真正要用的模型\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "#经验网络,用于评估一个状态的分数\n",
    "next_model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "#把model的参数复制给next_model\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "1"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state, epsilon=0.01):\n",
    "    \"\"\"Pick an action for ``state`` with an epsilon-greedy policy.\n",
    "\n",
    "    Args:\n",
    "        state: the 4 numbers describing the game state (list or array).\n",
    "        epsilon: probability of taking a random action instead of the\n",
    "            network's choice. Defaults to 0.01, the previously hard-coded value.\n",
    "\n",
    "    Returns:\n",
    "        int: the chosen action, 0 or 1.\n",
    "    \"\"\"\n",
    "    # explore: occasionally take a uniformly random action\n",
    "    if random.random() < epsilon:\n",
    "        return random.choice([0, 1])\n",
    "\n",
    "    # exploit: run the network on a (1, 4) float32 batch holding this state\n",
    "    state = torch.as_tensor(state, dtype=torch.float32).reshape(1, 4)\n",
    "\n",
    "    # choosing an action needs no gradients, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        return model(state).argmax().item()\n",
    "\n",
    "\n",
    "get_action([0.0013847, -0.01194451, 0.04260966, 0.00688801])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\masaikk\\AppData\\Local\\Temp\\ipykernel_15052\\2390167335.py:10: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\utils\\tensor_new.cpp:233.)\n",
      "  state = torch.FloatTensor(state).reshape(1, 4)\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "expected sequence of length 4 at dim 1 (got 0)",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mValueError\u001B[0m                                Traceback (most recent call last)",
      "Input \u001B[1;32mIn [7]\u001B[0m, in \u001B[0;36m<cell line: 39>\u001B[1;34m()\u001B[0m\n\u001B[0;32m     34\u001B[0m         datas\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;241m0\u001B[39m)\n\u001B[0;32m     36\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m update_count, drop_count\n\u001B[1;32m---> 39\u001B[0m \u001B[43mupdate_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;28mlen\u001B[39m(datas)\n",
      "Input \u001B[1;32mIn [7]\u001B[0m, in \u001B[0;36mupdate_data\u001B[1;34m()\u001B[0m\n\u001B[0;32m     15\u001B[0m over \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m     16\u001B[0m \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m over:\n\u001B[0;32m     17\u001B[0m     \u001B[38;5;66;03m#根据当前状态得到一个动作\u001B[39;00m\n\u001B[1;32m---> 18\u001B[0m     action \u001B[38;5;241m=\u001B[39m \u001B[43mget_action\u001B[49m\u001B[43m(\u001B[49m\u001B[43mstate\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     20\u001B[0m     \u001B[38;5;66;03m#执行动作,得到反馈\u001B[39;00m\n\u001B[0;32m     21\u001B[0m     next_state, reward, over, _ \u001B[38;5;241m=\u001B[39m env\u001B[38;5;241m.\u001B[39mstep(action)\n",
      "Input \u001B[1;32mIn [6]\u001B[0m, in \u001B[0;36mget_action\u001B[1;34m(state)\u001B[0m\n\u001B[0;32m      7\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m random\u001B[38;5;241m.\u001B[39mchoice([\u001B[38;5;241m0\u001B[39m, \u001B[38;5;241m1\u001B[39m])\n\u001B[0;32m      9\u001B[0m \u001B[38;5;66;03m#走神经网络,得到一个动作\u001B[39;00m\n\u001B[1;32m---> 10\u001B[0m state \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mFloatTensor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mstate\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m4\u001B[39m)\n\u001B[0;32m     12\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m model(state)\u001B[38;5;241m.\u001B[39margmax()\u001B[38;5;241m.\u001B[39mitem()\n",
      "\u001B[1;31mValueError\u001B[0m: expected sequence of length 4 at dim 1 (got 0)"
     ]
    }
   ],
   "source": [
    "# Replay buffer: a list of (state, action, reward, next_state, over) transitions.\n",
    "datas = []\n",
    "\n",
    "\n",
    "def update_data(min_new=200, capacity=10000):\n",
    "    \"\"\"Play whole episodes until enough new transitions are stored, then trim the buffer.\n",
    "\n",
    "    Supports both gym conventions: the older one, where reset() returns the\n",
    "    observation and step() returns 4 values, and the newer one (gym 0.26+),\n",
    "    where reset() returns (observation, info) and step() returns 5 values.\n",
    "\n",
    "    Args:\n",
    "        min_new: new episodes keep being started until at least this many\n",
    "            transitions have been added.\n",
    "        capacity: maximum buffer size; the oldest transitions beyond it are dropped.\n",
    "\n",
    "    Returns:\n",
    "        (update_count, drop_count): transitions added and transitions dropped.\n",
    "    \"\"\"\n",
    "    old_count = len(datas)\n",
    "\n",
    "    # Every episode is played to its end, so more than min_new entries may be added.\n",
    "    while len(datas) - old_count < min_new:\n",
    "        # Start a new game.\n",
    "        state = env.reset()\n",
    "        # Newer API: reset() returns (observation, info). Assumes a plain\n",
    "        # observation is never itself a tuple (holds for CartPole arrays).\n",
    "        if isinstance(state, tuple):\n",
    "            state = state[0]\n",
    "\n",
    "        over = False\n",
    "        while not over:\n",
    "            # Choose an action for the current state.\n",
    "            action = get_action(state)\n",
    "\n",
    "            # Apply it and read the feedback.\n",
    "            result = env.step(action)\n",
    "            if len(result) == 5:\n",
    "                # Newer API: the done flag is split into terminated / truncated.\n",
    "                next_state, reward, terminated, truncated, _ = result\n",
    "                over = bool(terminated or truncated)\n",
    "            else:\n",
    "                next_state, reward, over, _ = result\n",
    "\n",
    "            # Record the transition, then advance to the next step.\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - capacity, 0)\n",
    "\n",
    "    # Drop the oldest entries with one slice deletion instead of repeated pop(0).\n",
    "    del datas[:drop_count]\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/ipykernel_launcher.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  /opt/conda/conda-bld/pytorch_1640811701593/work/torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  import sys\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-1.1208e+00, -5.8698e-01, -1.6706e-02, -5.9305e-02],\n",
       "         [-4.8614e-03,  1.4742e-01, -3.7731e-03, -2.1355e-01],\n",
       "         [ 1.1511e-03,  1.4817e-01, -1.2478e-02, -2.3001e-01],\n",
       "         [-5.1451e-01, -3.9776e-01, -4.3573e-03, -2.2192e-01],\n",
       "         [-3.9565e-01, -3.9709e-01,  1.4259e-02, -2.3670e-01],\n",
       "         [-2.2995e-02,  1.4952e-01,  2.4450e-02, -2.5993e-01],\n",
       "         [-1.2419e+00, -7.8057e-01, -7.2140e-03,  1.9946e-01],\n",
       "         [-1.2029e+00, -5.8577e-01, -7.7445e-03, -8.6094e-02],\n",
       "         [ 8.6021e-05, -4.7436e-02, -1.0830e-02,  7.3364e-02],\n",
       "         [-4.1544e-01, -3.9743e-01,  1.0734e-02, -2.2922e-01],\n",
       "         [-1.0109e+00, -3.9503e-01, -1.8839e-02, -2.8212e-01],\n",
       "         [-6.5669e-01, -5.9022e-01, -9.2182e-03,  1.2090e-02],\n",
       "         [ 1.3677e-02,  1.5247e-01, -4.0947e-02, -3.2515e-01],\n",
       "         [-5.2246e-01, -5.9282e-01, -8.7956e-03,  6.9389e-02],\n",
       "         [-4.3524e-01, -3.9767e-01,  7.4865e-03, -2.2390e-01],\n",
       "         [-1.5043e-01, -5.9863e-01, -5.1039e-02,  1.9790e-01],\n",
       "         [-4.9469e-01, -3.9786e-01, -1.4163e-03, -2.1964e-01],\n",
       "         [-1.4097e+00, -3.8895e-01, -2.9505e-02, -4.1645e-01],\n",
       "         [-2.1816e-01, -5.9453e-01, -3.7744e-02,  1.0713e-01],\n",
       "         [-7.7063e-01, -7.8437e-01,  2.3514e-03,  2.8326e-01],\n",
       "         [-2.8658e-03,  1.4759e-01, -6.4852e-03, -2.1727e-01],\n",
       "         [-7.3137e-01, -5.8930e-01, -2.9467e-03, -8.3129e-03],\n",
       "         [-3.9059e-03, -4.7775e-02, -5.3898e-03,  8.0832e-02],\n",
       "         [-1.0306e+00, -3.9441e-01, -2.4390e-02, -2.9574e-01],\n",
       "         [-1.8241e-01, -4.0144e-01, -4.5949e-02, -1.4078e-01],\n",
       "         [-4.7487e-01, -3.9788e-01,  1.4886e-03, -2.1920e-01],\n",
       "         [-5.8959e-03, -4.7831e-02, -2.8017e-03,  8.2081e-02],\n",
       "         [-8.6098e-01, -5.9045e-01,  1.3830e-02,  1.7112e-02],\n",
       "         [-1.4835e+00, -7.7596e-01, -4.1906e-02,  9.7851e-02],\n",
       "         [-9.9122e-01, -3.9550e-01, -1.3734e-02, -2.7177e-01],\n",
       "         [-4.6301e-01, -5.9300e-01,  1.9070e-05,  7.3477e-02],\n",
       "         [-1.2575e+00, -5.8534e-01, -3.2248e-03, -9.5488e-02],\n",
       "         [-1.4235e-01, -4.0425e-01, -4.9464e-02, -7.8773e-02],\n",
       "         [-2.5770e-01, -7.8819e-01, -2.6169e-02,  3.6775e-01],\n",
       "         [-1.1326e+00, -7.8186e-01, -1.7892e-02,  2.2806e-01],\n",
       "         [-1.2829e-02,  1.4750e-01,  6.7398e-03, -2.1514e-01],\n",
       "         [-5.3432e-01, -3.9757e-01, -7.4079e-03, -2.2606e-01],\n",
       "         [ 2.0941e-03, -4.7147e-02, -1.3818e-02,  6.6999e-02],\n",
       "         [-1.1599e+00, -7.8142e-01, -1.4735e-02,  2.1824e-01],\n",
       "         [-6.9684e-02, -6.0536e-01, -7.1999e-02,  3.4714e-01],\n",
       "         [-1.5106e+00, -7.7479e-01, -4.4104e-02,  7.2065e-02],\n",
       "         [-1.2302e+00, -5.8552e-01, -5.3836e-03, -9.1518e-02],\n",
       "         [-2.0923e-02,  1.4890e-01,  2.0059e-02, -2.4618e-01],\n",
       "         [-4.4320e-01, -5.9290e-01,  3.0084e-03,  7.1132e-02],\n",
       "         [-1.1052e+00, -7.8241e-01, -2.1507e-02,  2.4008e-01],\n",
       "         [ 1.8147e-02, -2.3496e-01, -5.6939e-02,  1.9882e-01],\n",
       "         [-3.5623e-01, -5.9185e-01,  5.4134e-03,  4.8062e-02],\n",
       "         [-9.3992e-01, -5.9122e-01, -5.0591e-03,  3.4194e-02],\n",
       "         [-8.6399e-03, -2.2715e-01, -6.1812e-02,  2.6131e-02],\n",
       "         [-1.3395e+00, -5.8515e-01,  2.5880e-03, -9.9684e-02],\n",
       "         [ 7.9982e-03, -3.7531e-02, -5.1902e-02, -1.4532e-01],\n",
       "         [-3.4049e-01, -7.8699e-01, -1.4104e-03,  3.4119e-01],\n",
       "         [-3.1291e-01, -7.8716e-01, -9.2929e-03,  3.4486e-01],\n",
       "         [-3.4065e-03, -2.2879e-01, -5.8097e-02,  6.2339e-02],\n",
       "         [-8.7279e-01, -3.9553e-01,  1.4172e-02, -2.7118e-01],\n",
       "         [-9.2017e-01, -5.9123e-01, -5.7537e-04,  3.4340e-02],\n",
       "         [-1.6846e-02,  1.4801e-01,  1.2741e-02, -2.2640e-01],\n",
       "         [-8.8070e-01, -5.9085e-01,  8.7485e-03,  2.5944e-02],\n",
       "         [-7.9810e-01, -7.8452e-01,  7.8431e-03,  2.8653e-01],\n",
       "         [-1.6533e+00, -7.6726e-01, -6.0870e-02, -9.4303e-02],\n",
       "         [-9.5174e-01, -3.9603e-01, -4.3752e-03, -2.6008e-01],\n",
       "         [-1.4719e+00, -5.8140e-01, -3.8255e-02, -1.8252e-01],\n",
       "         [-4.4456e-02, -2.1932e-01, -6.9270e-02, -1.4684e-01],\n",
       "         [-1.4447e+00, -5.8252e-01, -3.7552e-02, -1.5788e-01]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[-1.1326e+00, -7.8186e-01, -1.7892e-02,  2.2806e-01],\n",
       "         [-1.9130e-03, -4.7644e-02, -8.0441e-03,  7.7944e-02],\n",
       "         [ 4.1145e-03, -4.6772e-02, -1.7078e-02,  5.8710e-02],\n",
       "         [-5.2246e-01, -5.9282e-01, -8.7956e-03,  6.9389e-02],\n",
       "         [-4.0359e-01, -5.9241e-01,  9.5252e-03,  6.0447e-02],\n",
       "         [-2.0004e-02, -4.5940e-02,  1.9252e-02,  4.0363e-02],\n",
       "         [-1.2575e+00, -5.8534e-01, -3.2248e-03, -9.5488e-02],\n",
       "         [-1.2146e+00, -7.8078e-01, -9.4663e-03,  2.0414e-01],\n",
       "         [-8.6270e-04,  1.4784e-01, -9.3632e-03, -2.2272e-01],\n",
       "         [-4.2339e-01, -5.9270e-01,  6.1498e-03,  6.6834e-02],\n",
       "         [-1.0188e+00, -5.8988e-01, -2.4481e-02,  4.5663e-03],\n",
       "         [-6.6849e-01, -3.9497e-01, -8.9765e-03, -2.8349e-01],\n",
       "         [ 1.6727e-02, -4.2043e-02, -4.7450e-02, -4.5654e-02],\n",
       "         [-5.3432e-01, -3.9757e-01, -7.4079e-03, -2.2606e-01],\n",
       "         [-4.4320e-01, -5.9290e-01,  3.0084e-03,  7.1132e-02],\n",
       "         [-1.6240e-01, -4.0282e-01, -4.7081e-02, -1.1043e-01],\n",
       "         [-5.0265e-01, -5.9296e-01, -5.8092e-03,  7.2593e-02],\n",
       "         [-1.4174e+00, -5.8364e-01, -3.7834e-02, -1.3321e-01],\n",
       "         [-2.3005e-01, -7.8909e-01, -3.5602e-02,  3.8767e-01],\n",
       "         [-7.8632e-01, -5.8928e-01,  8.0166e-03, -8.6764e-03],\n",
       "         [ 8.6021e-05, -4.7436e-02, -1.0830e-02,  7.3364e-02],\n",
       "         [-7.4316e-01, -7.8438e-01, -3.1129e-03,  2.8344e-01],\n",
       "         [-4.8614e-03,  1.4742e-01, -3.7731e-03, -2.1355e-01],\n",
       "         [-1.0385e+00, -5.8918e-01, -3.0304e-02, -1.0847e-02],\n",
       "         [-1.9043e-01, -5.9588e-01, -4.8765e-02,  1.3706e-01],\n",
       "         [-4.8283e-01, -5.9302e-01, -2.8954e-03,  7.3953e-02],\n",
       "         [-6.8525e-03,  1.4733e-01, -1.1601e-03, -2.1148e-01],\n",
       "         [-8.7279e-01, -3.9553e-01,  1.4172e-02, -2.7118e-01],\n",
       "         [-1.4990e+00, -5.8026e-01, -3.9949e-02, -2.0775e-01],\n",
       "         [-9.9913e-01, -5.9042e-01, -1.9170e-02,  1.6553e-02],\n",
       "         [-4.7487e-01, -3.9788e-01,  1.4886e-03, -2.1920e-01],\n",
       "         [-1.2693e+00, -7.8042e-01, -5.1345e-03,  1.9618e-01],\n",
       "         [-1.5043e-01, -5.9863e-01, -5.1039e-02,  1.9790e-01],\n",
       "         [-2.7346e-01, -5.9271e-01, -1.8814e-02,  6.6933e-02],\n",
       "         [-1.1482e+00, -5.8649e-01, -1.3331e-02, -7.0211e-02],\n",
       "         [-9.8787e-03, -4.7722e-02,  2.4370e-03,  7.9663e-02],\n",
       "         [-5.4227e-01, -5.9259e-01, -1.1929e-02,  6.4281e-02],\n",
       "         [ 1.1511e-03,  1.4817e-01, -1.2478e-02, -2.3001e-01],\n",
       "         [-1.1755e+00, -5.8609e-01, -1.0370e-02, -7.9058e-02],\n",
       "         [-8.1792e-02, -4.0930e-01, -6.5057e-02,  3.2646e-02],\n",
       "         [-1.5261e+00, -5.7906e-01, -4.2662e-02, -2.3420e-01],\n",
       "         [-1.2419e+00, -7.8057e-01, -7.2140e-03,  1.9946e-01],\n",
       "         [-1.7945e-02, -4.6502e-02,  1.5135e-02,  5.2757e-02],\n",
       "         [-4.5505e-01, -3.9782e-01,  4.4311e-03, -2.2060e-01],\n",
       "         [-1.1208e+00, -5.8698e-01, -1.6706e-02, -5.9305e-02],\n",
       "         [ 1.3448e-02, -3.9074e-02, -5.2962e-02, -1.1126e-01],\n",
       "         [-3.6807e-01, -7.8705e-01,  6.3746e-03,  3.4245e-01],\n",
       "         [-9.5174e-01, -3.9603e-01, -4.3752e-03, -2.6008e-01],\n",
       "         [-1.3183e-02, -3.1195e-02, -6.1289e-02, -2.8540e-01],\n",
       "         [-1.3512e+00, -3.9007e-01,  5.9433e-04, -3.9155e-01],\n",
       "         [ 7.2476e-03, -2.3187e-01, -5.4809e-02,  1.3055e-01],\n",
       "         [-3.5623e-01, -5.9185e-01,  5.4134e-03,  4.8062e-02],\n",
       "         [-3.2865e-01, -5.9191e-01, -2.3957e-03,  4.9263e-02],\n",
       "         [-7.9823e-03, -3.2881e-02, -5.6850e-02, -2.4809e-01],\n",
       "         [-8.8070e-01, -5.9085e-01,  8.7485e-03,  2.5944e-02],\n",
       "         [-9.3200e-01, -3.9610e-01,  1.1143e-04, -2.5852e-01],\n",
       "         [-1.3886e-02, -4.7296e-02,  8.2127e-03,  7.0271e-02],\n",
       "         [-8.9251e-01, -3.9585e-01,  9.2674e-03, -2.6397e-01],\n",
       "         [-8.1379e-01, -5.8951e-01,  1.3574e-02, -3.6739e-03],\n",
       "         [-1.6686e+00, -9.6146e-01, -6.2756e-02,  1.7857e-01],\n",
       "         [-9.5967e-01, -5.9109e-01, -9.5768e-03,  3.1219e-02],\n",
       "         [-1.4835e+00, -7.7596e-01, -4.1906e-02,  9.7851e-02],\n",
       "         [-4.8843e-02, -4.1339e-01, -7.2207e-02,  1.2321e-01],\n",
       "         [-1.4563e+00, -7.7708e-01, -4.0710e-02,  1.2272e-01]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_sample(batch_size=64):\n",
    "    \"\"\"Draw a random mini-batch of transitions from the replay buffer `datas`.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of transitions to draw; random.sample raises\n",
    "            ValueError when the buffer holds fewer entries than this.\n",
    "\n",
    "    Returns:\n",
    "        (state, action, reward, next_state, over) tensors shaped\n",
    "        [b, 4], [b, 1], [b, 1], [b, 4], [b, 1]. state, reward and next_state\n",
    "        are float32; action and over are int64.\n",
    "    \"\"\"\n",
    "    # Imported locally so this cell does not rely on an import made elsewhere.\n",
    "    import numpy as np\n",
    "\n",
    "    samples = random.sample(datas, batch_size)\n",
    "    states, actions, rewards, next_states, overs = zip(*samples)\n",
    "\n",
    "    # Stack every field into one ndarray before building the tensor: creating a\n",
    "    # tensor straight from a list of ndarrays is very slow and PyTorch warns about it.\n",
    "    # [b, 4]\n",
    "    state = torch.tensor(np.array(states), dtype=torch.float32).reshape(-1, 4)\n",
    "    # [b, 1]\n",
    "    action = torch.tensor(np.array(actions), dtype=torch.int64).reshape(-1, 1)\n",
    "    # [b, 1]\n",
    "    reward = torch.tensor(np.array(rewards), dtype=torch.float32).reshape(-1, 1)\n",
    "    # [b, 4]\n",
    "    next_state = torch.tensor(np.array(next_states), dtype=torch.float32).reshape(-1, 4)\n",
    "    # [b, 1], 1 where the episode ended on that step\n",
    "    over = torch.tensor(np.array(overs), dtype=torch.int64).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "# Keep one batch in the notebook namespace; the following cells use it to check shapes.\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    \"\"\"Return the model's score for the action actually taken in each state.\n",
    "\n",
    "    Only the state and the chosen action are used. The estimate is made\n",
    "    before the action is executed, when the reward and next_state are still\n",
    "    unknown, so they cannot and need not be considered here.\n",
    "    \"\"\"\n",
    "    # [b, 4] -> [b, 2]: one score per action for every state in the batch\n",
    "    scores = model(state)\n",
    "\n",
    "    # [b, 2] -> [b, 1]: keep only the column selected by `action`\n",
    "    return scores.gather(1, action)\n",
    "\n",
    "\n",
    "get_value(state, action).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    \"\"\"Build the training target for a batch of transitions.\n",
    "\n",
    "    There is no exact answer for how much a state is worth, so the target is\n",
    "    estimated from experience: the reward plus the discounted best score that\n",
    "    the delayed-update network `next_model` gives to next_state.\n",
    "    \"\"\"\n",
    "    # [b, 4] -> [b, 2]: score next_state without tracking gradients\n",
    "    with torch.no_grad():\n",
    "        next_scores = next_model(next_state)\n",
    "\n",
    "    # [b, 2] -> [b, 1]: best score over all actions\n",
    "    best_next, _ = next_scores.max(dim=1)\n",
    "    best_next = best_next.reshape(-1, 1)\n",
    "\n",
    "    # In place, in this order:\n",
    "    #   1. weight the next-state score by 0.98;\n",
    "    #   2. zero it where the game is already over (nothing left to play);\n",
    "    #   3. add the reward to obtain the final target.\n",
    "    # [b, 1] * [b, 1] + [b, 1] -> [b, 1]\n",
    "    return best_next.mul_(0.98).mul_(1 - over).add_(reward)\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play one episode with get_action and return the summed reward (higher is better).\n",
    "\n",
    "    Supports both gym conventions: reset() returning the observation or\n",
    "    (observation, info), and step() returning 4 or 5 values.\n",
    "\n",
    "    Args:\n",
    "        play: when truthy, redraw the game on roughly 20% of the steps.\n",
    "\n",
    "    Returns:\n",
    "        The sum of the rewards collected during the episode.\n",
    "    \"\"\"\n",
    "    # Start a new game.\n",
    "    state = env.reset()\n",
    "    # Newer API: reset() returns (observation, info). Assumes a plain\n",
    "    # observation is never itself a tuple (holds for CartPole arrays).\n",
    "    if isinstance(state, tuple):\n",
    "        state = state[0]\n",
    "\n",
    "    reward_sum = 0\n",
    "\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Choose an action for the current state.\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Apply it and read the feedback.\n",
    "        result = env.step(action)\n",
    "        if len(result) == 5:\n",
    "            # Newer API: the done flag is split into terminated / truncated.\n",
    "            state, reward, terminated, truncated, _ = result\n",
    "            over = bool(terminated or truncated)\n",
    "        else:\n",
    "            state, reward, over, _ = result\n",
    "        reward_sum += reward\n",
    "\n",
    "        # Animate, skipping most frames.\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 9.4\n",
      "50 10000 200 200 200.0\n",
      "100 10000 200 200 200.0\n",
      "150 10000 200 200 200.0\n",
      "200 10000 384 384 185.9\n",
      "250 10000 375 375 176.9\n",
      "300 10000 200 200 200.0\n",
      "350 10000 200 200 186.05\n",
      "400 10000 200 200 200.0\n",
      "450 10000 200 200 200.0\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    \"\"\"Train `model` with DQN updates and print a progress line every 50 epochs.\n",
    "\n",
    "    Each of the 500 epochs refreshes the replay buffer and then runs 200\n",
    "    optimisation steps; `next_model` receives a copy of the weights of\n",
    "    `model` after every 10th step.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for epoch in range(500):\n",
    "        # Collect fresh transitions before learning.\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        for step in range(1, 201):\n",
    "            # Sample a batch and compare the predicted value with its target.\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "            loss = loss_fn(get_value(state, action),\n",
    "                           get_target(reward, next_state, over))\n",
    "\n",
    "            # Gradient step.\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Copy the parameters of model into next_model.\n",
    "            if step % 10 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 50 != 0:\n",
    "            continue\n",
    "\n",
    "        # Average score over 20 test games.\n",
    "        scores = [test(play=False) for _ in range(20)]\n",
    "        test_result = sum(scores) / 20\n",
    "        print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAASdUlEQVR4nO3df6zV933f8efLmGArsWxTX1uEH4UlVJrxVlxdsUjuNuZkNfWWkkjzRKpGSLVE/nCkRKuy2a3WJn8gdVPj7J8lElmsoiwNoUo8IytdS5mzKFJqgmPsgDE1NSi+gQAm8/yDiF9+74/7ZT6Fe7mH+yOXzz3Ph3R0vud9Pt9z3m/L98WXL99zT6oKSVI7rpvtBiRJV8fglqTGGNyS1BiDW5IaY3BLUmMMbklqzIwFd5J1SQ4mOZTk4Zl6H0kaNJmJ67iTzAP+FviXwAjwA+BjVfXCtL+ZJA2YmTriXgMcqqqXq+ossA1YP0PvJUkD5foZet3FwCs9j0eAfzLe4ttuu62WL18+Q61IUnuOHDnCq6++mrGem6ngHuvN/t45mSSbgE0Ay5YtY8+ePTPUiiS1Z3h4eNznZupUyQiwtOfxEuBo74Kq2lJVw1U1PDQ0NENtSNLcM1PB/QNgZZIVSd4FbAB2zNB7SdJAmZFTJVV1Pskngb8E5gGPVdX+mXgvSRo0M3WOm6r6NvDtmXp9SRpUfnJSkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY0xuCWpMQa3JDXG4JakxhjcktQYg1uSGmNwS1JjpvTVZUmOAG8AF4DzVTWcZCHwDWA5cAT4t1X1f6bWpiTpouk44v4XVbW6qoa7xw8Du6pqJbCreyxJmiYzcapkPbC1294KfGQG3kOSBtZUg7uAv0ryTJJNXe2OqjoG0N3fPsX3kCT1mNI5buCeqjqa5HZgZ5IX+92xC/pNAMuWLZtiG5I0OKZ0xF1VR7v7E8DjwBrgeJJFAN39iXH23VJVw1U1PDQ0NJU2JGmgTDq4k7w7yU0Xt4HfAPYBO4CN3bKNwBNTbVKS9I6pnCq5A3g8ycXX+bOq+p9JfgBsT/Ig8GPggam3KUm6aNLBXVUvA786Rv0U8MGpNCVJGp+fnJSkxhjcktQYg1uSGmNwS1JjDG5JaozBLUmNMbglqTEGtyQ1xuCWpMYY3JLUGINbkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5IaM2FwJ3ksyYkk+3pqC5PsTPJSd39rz3OPJDmU5GCS+2aqcUkaVP0ccf8psO6S2sPArqpaCezqHpPkTmADsKrb54tJ5k1bt5KkiYO7qr4L/OyS8npga7e9FfhIT31bVZ2pqsPAIWDN9LQqSYLJn+O+o6qOAXT3t3f1xcArPetGutplkmxKsifJnpMnT06yDUkaPNP9j5MZo1ZjLayqLVU1XFXDQ0ND09yGJM1dkw3u40kWAXT3J7r6CLC0Z90S4Ojk25MkXWqywb0D2NhtbwSe6KlvSLIgyQpgJbB7ai1KknpdP9GCJF8H1gK3JRkB/gj4Y2B7kgeBHwMPAFTV/iTbgReA88BDVXVhhnqXpIE0YXBX1cfGeeqD46zfDGyeSlOSpPH5yUlJaozBLUmNMbglqTEGtyQ1xuCWpMYY3JLUGINbkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY2ZMLiTPJbkRJJ9PbXPJvlJkr3d7f6e5x5JcijJwST3zVTjkjSo+jni/lNg3Rj1L1TV6u72bYAkdwIbgFXdPl9MMm+6mpUk9RHcVfVd4Gd9vt56YFtVnamqw8AhYM0U+pMkXWIq57g/meT57lTKrV1tMfBKz5qRrnaZJJuS7Emy5+TJk1NoQ5IGy2SD+0vA+4DVwDHg8109Y6ytsV6g
qrZU1XBVDQ8NDU2yDUkaPJMK7qo6XlUXqupt4Mu8czpkBFjas3QJcHRqLUqSek0quJMs6nn4UeDiFSc7gA1JFiRZAawEdk+tRUlSr+snWpDk68Ba4LYkI8AfAWuTrGb0NMgR4BMAVbU/yXbgBeA88FBVXZiRziVpQE0Y3FX1sTHKX7nC+s3A5qk0JUkan5+clKTGGNyS1BiDW5IaY3BLUmMMbklqzIRXlUjq34VzZ3jrxOHL6tfNX8C7h5aTjPXhYunqGNzSNDr7xikOPvkFLv1NDzfe+l5WPfCHjP1bIaSr46kSSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY0xuCWpMRMGd5KlSZ5KciDJ/iSf6uoLk+xM8lJ3f2vPPo8kOZTkYJL7ZnIASRo0/Rxxnwd+r6r+IfAB4KEkdwIPA7uqaiWwq3tM99wGYBWwDvhiknkz0bwkDaIJg7uqjlXVD7vtN4ADwGJgPbC1W7YV+Ei3vR7YVlVnquowcAhYM819S9LAuqpz3EmWA3cDTwN3VNUxGA134PZu2WLglZ7dRrrapa+1KcmeJHtOnjw5idYlaTD1HdxJ3gN8E/h0Vb1+paVj1OqyQtWWqhququGhoaF+25CkgddXcCeZz2hof62qvtWVjydZ1D2/CDjR1UeApT27LwGOTk+7kqR+rioJ8BXgQFU92vPUDmBjt70ReKKnviHJgiQrgJXA7ulrWZIGWz9fXXYP8HHgR0n2drXfB/4Y2J7kQeDHwAMAVbU/yXbgBUavSHmoqi5Md+OSNKgmDO6q+h7jf1HeB8fZZzOweQp9SZLG4ScnJakxBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY0xuCWpMQa3JDXG4JakxhjcktQYg1uSGmNwS1JjDG5JaozBLUmNMbglqTEGtyQ1pp8vC16a5KkkB5LsT/Kprv7ZJD9Jsre73d+zzyNJDiU5mOS+mRxAkgZNP18WfB74var6YZKbgGeS7Oye+0JV/Unv4iR3AhuAVcB7gb9O8it+YbAkTY8Jj7ir6lhV/bDbfgM4ACy+wi7rgW1VdaaqDgOHgDXT0awk6SrPcSdZDtwNPN2VPpnk+SSPJbm1qy0GXunZbYQrB70k6Sr0HdxJ3gN8E/h0Vb0OfAl4H7AaOAZ8/uLSMXavMV5vU5I9SfacPHnyavuWpIHVV3Anmc9oaH+tqr4FUFXHq+pCVb0NfJl3ToeMAEt7dl8CHL30NatqS1UNV9Xw0NDQVGaQpIHSz1UlAb4CHKiqR3vqi3qWfRTY123vADYkWZBkBbAS2D19LUvSYOvnqpJ7gI8DP0qyt6v9PvCxJKsZPQ1yBPgEQFXtT7IdeIHRK1Ie8ooSSZo+EwZ3VX2Psc9bf/sK+2wGNk+hL0nSOPzkpCQ1xuCWpMYY3JLUGINbkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5Ia089vB5QG2tmzZ3n22We5cGHiX3KZM68xj7rst7KdPn2a73//b/p6v5tvvplVq1ZNolMNCoNbmsCpU6e49957OX369IRrVyy6ha//x39Drvv70f3y4Zf57d+9h7rsu6Aut3btWp566qnJtqsBYHBL06wIp84u4sTZX+b6nGXxDS8BP5vttjSHGNzStApHz7yfA2/9Uy50P17Hzr6fm89/Y5b70lziP05K0+itCzez/81f5wLzGf3+kfDWhVv40Zv/nKqxvo9EunoGtzSNiut4e4y/yF6o+bPQjeaqfr4s+IYku5M8l2R/ks919YVJdiZ5qbu/tWefR5IcSnIwyX0zOYB0LZnHed6Vn19Wv+G6N2ehG81V/RxxnwHurapfBVYD65J8AHgY2FVVK4Fd3WOS3AlsAFYB64AvJpk3A71L15wb573O6pt2ceN1bwBvcx3n+aX5I/yjm/43SR+XlEh96OfLggu4eLgwv7sVsB5Y29W3At8B
/kNX31ZVZ4DDSQ4Ba4Dvj/ce586d46c//enkJpBm2IkTJ6h+ruMDXv2/p9ny53/O6bf/ktfO3cF1Oc9t83/Cz3/+Rl+XAsLodeP+POjcuXPjPtfXVSXdEfMzwPuB/1pVTye5o6qOAVTVsSS3d8sXA72fNBjpauM6deoUX/3qV/tpRfqFe/311zl//nxfa984fZb/8b0Xp/R+x48f9+dBnDp1atzn+gruqroArE5yC/B4kruusHysfzq/7FgjySZgE8CyZcv4zGc+008r0i/csWPHePTRR694BDSdli5d6s+D+MY3xr+E9KquKqmq1xg9JbIOOJ5kEUB3f6JbNgIs7dltCXB0jNfaUlXDVTU8NDR0NW1I0kDr56qSoe5ImyQ3Ah8CXgR2ABu7ZRuBJ7rtHcCGJAuSrABWArunuW9JGlj9nCpZBGztznNfB2yvqieTfB/YnuRB4MfAAwBVtT/JduAF4DzwUHeqRZI0Dfq5quR54O4x6qeAD46zz2Zg85S7kyRdxk9OSlJjDG5Jaoy/HVCawIIFC/jwhz/MmTNnfiHvd9ddV7raVjK4pQktXLiQbdu2zXYb0v/nqRJJaozBLUmNMbglqTEGtyQ1xuCWpMYY3JLUGINbkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1Jh+viz4hiS7kzyXZH+Sz3X1zyb5SZK93e3+nn0eSXIoycEk983kAJI0aPr5fdxngHur6s0k84HvJfmL7rkvVNWf9C5OciewAVgFvBf46yS/4hcGS9L0mPCIu0a92T2c393qCrusB7ZV1ZmqOgwcAtZMuVNJEtDnOe4k85LsBU4AO6vq6e6pTyZ5PsljSW7taouBV3p2H+lqkqRp0FdwV9WFqloNLAHWJLkL+BLwPmA1cAz4fLc8Y73EpYUkm5LsSbLn5MmTk2hdkgbTVV1VUlWvAd8B1lXV8S7Q3wa+zDunQ0aApT27LQGOjvFaW6pquKqGh4aGJtO7JA2kfq4qGUpyS7d9I/Ah4MUki3qWfRTY123vADYkWZBkBbAS2D2tXUvSAOvnqpJFwNYk8xgN+u1V9WSSryZZzehpkCPAJwCqan+S7cALwHngIa8okaTpM2FwV9XzwN1j1D9+hX02A5un1pokaSx+clKSGmNwS1JjDG5JaozBLUmNMbglqTEGtyQ1xuCWpMYY3JLUGINbkhpjcEtSYwxuSWqMwS1JjTG4JakxBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY0xuCWpMQa3JDUmVTXbPZDkJPAW8Ops9zIDbsO5WjNXZ3OutvxyVQ2N9cQ1EdwASfZU1fBs9zHdnKs9c3U255o7PFUiSY0xuCWpMddScG+Z7QZmiHO1Z67O5lxzxDVzjluS1J9r6YhbktSHWQ/uJOuSHExyKMnDs93P1UryWJITSfb11BYm2Znkpe7+1p7nHulmPZjkvtnpemJJliZ5KsmBJPuTfKqrNz1bkhuS7E7yXDfX57p603NdlGRekmeTPNk9nitzHUnyoyR7k+zpanNitkmpqlm7AfOAvwP+AfAu4DngztnsaRIz/DPg14B9PbX/DDzcbT8M/Kdu+85uxgXAim72ebM9wzhzLQJ+rdu+Cfjbrv+mZwMCvKfbng88DXyg9bl65vt3wJ8BT86V/xe7fo8At11SmxOzTeY220fca4BDVfVyVZ0FtgHrZ7mnq1JV3wV+dkl5PbC1294KfKSnvq2qzlTVYeAQo/8NrjlVdayqfthtvwEcABbT+Gw16s3u4fzuVjQ+F0CSJcC/Av5bT7n5ua5gLs92RbMd3IuBV3oej3S11t1RVcdgNACB27t6k/MmWQ7czejRafOzdacT9gIngJ1VNSfmAv4L8O+Bt3tqc2EuGP3D9a+SPJNkU1ebK7Ndtetn+f0zRm0uX+bS3LxJ3gN8E/h0Vb2ejDXC6NIxatfkbFV1AVid5Bbg8SR3XWF5E3Ml+dfAiap6Jsna
fnYZo3bNzdXjnqo6muR2YGeSF6+wtrXZrtpsH3GPAEt7Hi8Bjs5SL9PpeJJFAN39ia7e1LxJ5jMa2l+rqm915TkxG0BVvQZ8B1hH+3PdA/xWkiOMnnK8N8l/p/25AKiqo939CeBxRk99zInZJmO2g/sHwMokK5K8C9gA7JjlnqbDDmBjt70ReKKnviHJgiQrgJXA7lnob0IZPbT+CnCgqh7tearp2ZIMdUfaJLkR+BDwIo3PVVWPVNWSqlrO6M/R/6qq36HxuQCSvDvJTRe3gd8A9jEHZpu02f7XUeB+Rq9Y+DvgD2a7n0n0/3XgGHCO0T/pHwR+CdgFvNTdL+xZ/wfdrAeB35zt/q8w168z+tfL54G93e3+1mcD/jHwbDfXPuAPu3rTc10y41reuaqk+bkYversue62/2JOzIXZJnvzk5OS1JjZPlUiSbpKBrckNcbglqTGGNyS1BiDW5IaY3BLUmMMbklqjMEtSY35f9dpguhUTxhhAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
